{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text generation with an RNN\n",
    "将使用Andrej Karpathy的《循环神经网络的不可思议的有效性》中的莎士比亚作品数据集。给定来自该数据的字符序列(“Shakespear”)，训练一个模型来预测序列中的下一个字符(“e”)。通过反复调用模型可以生成更长的文本序列。\n",
    "\n",
    "虽然有些句子合乎语法，但大多数都没有意义。模型没有学习单词的意思，但是考虑:\n",
    "* 该模型是基于字符的。当训练开始时，模型不知道如何拼写英语单词，甚至不知道单词是文本的一个单位。\n",
    "* 输出的结构类似于一个积木块，文本块通常以说话人的名字开头，所有的大写字母与数据集相似。\n",
    "* 如下所示，该模型在小批量文本(每个文本100个字符)上进行训练，并且仍然能够生成具有连贯结构的更长的文本序列。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 配置\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "import time"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 下载莎士比亚数据集\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n",
      "1122304/1115394 [==============================] - 1s 1us/step\n"
     ]
    }
   ],
   "source": [
    "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 阅读数据\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of text: 1115394 characters\n"
     ]
    }
   ],
   "source": [
    "# Read, then decode for py2 compat.\n",
    "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n",
    "# length of text is the number of characters in it\n",
    "print(f'Length of text: {len(text)} characters')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You are all resolved rather to die than to famish?\n",
      "\n",
      "All:\n",
      "Resolved. resolved.\n",
      "\n",
      "First Citizen:\n",
      "First, you know Caius Marcius is chief enemy to the people.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 查看前250个字符\n",
    "print(text[:250])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65 unique characters\n"
     ]
    }
   ],
   "source": [
    "# 文件中唯一的字符（将text装入set中，相当于计数文本中一共有多少个词（去重））\n",
    "vocab = sorted(set(text))\n",
    "print(f'{len(vocab)} unique characters')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 处理数据\n",
    "\n",
    "#### 向量化文本\n",
    "在训练之前，您需要将字符串转换为数字表示形式。\n",
    " tf.keras.layers.StringLookup层可以将每个字符转换为数字ID。它只需要首先将文本拆分为标记。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_texts = ['abcdefg', 'xyz']\n",
    "\n",
    "chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')\n",
    "chars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'tensorflow.keras.layers' has no attribute 'StringLookup'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32md:\\Github_code\\TensorFlow-note\\NLP\\1-使用RNN生成文本.ipynb 单元格 12\u001b[0m in \u001b[0;36m2\n\u001b[0;32m      <a href='vscode-notebook-cell:/d%3A/Github_code/TensorFlow-note/NLP/1-%E4%BD%BF%E7%94%A8RNN%E7%94%9F%E6%88%90%E6%96%87%E6%9C%AC.ipynb#X14sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39m# 现在创建tf.keras.layers.StringLookup层:\u001b[39;00m\n\u001b[1;32m----> <a href='vscode-notebook-cell:/d%3A/Github_code/TensorFlow-note/NLP/1-%E4%BD%BF%E7%94%A8RNN%E7%94%9F%E6%88%90%E6%96%87%E6%9C%AC.ipynb#X14sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m ids_from_chars \u001b[39m=\u001b[39m tf\u001b[39m.\u001b[39;49mkeras\u001b[39m.\u001b[39;49mlayers\u001b[39m.\u001b[39;49mStringLookup(\n\u001b[0;32m      <a href='vscode-notebook-cell:/d%3A/Github_code/TensorFlow-note/NLP/1-%E4%BD%BF%E7%94%A8RNN%E7%94%9F%E6%88%90%E6%96%87%E6%9C%AC.ipynb#X14sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m     vocabulary\u001b[39m=\u001b[39m\u001b[39mlist\u001b[39m(vocab), mask_token\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m)\n",
      "\u001b[1;31mAttributeError\u001b[0m: module 'tensorflow.keras.layers' has no attribute 'StringLookup'"
     ]
    }
   ],
   "source": [
    "# 现在创建tf.keras.layers.StringLookup层:\n",
    "ids_from_chars = tf.keras.layers.StringLookup(\n",
    "    vocabulary=list(vocab), mask_token=None)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
